#!/usr/bin/python

import argparse
import ast
import glob
import math
import re
import numpy as np
import pandas as pd
from bokeh.layouts import row, column
from bokeh.models import HoverTool, LegendItem, Legend, RangeSlider, Button
from bokeh.plotting import figure, show
from bokeh.palettes import Dark2_5 as palette, Bokeh
from bokeh.models import ColumnDataSource
from bokeh.models import CheckboxGroup, CustomJS
from bokeh.palettes import Viridis3
import itertools


def get_typed_value(value):
    """Convert a raw text field into a Python value where possible.

    The literal string ``'-nan'`` is mapped to ``np.nan``. Anything else is
    handed to ``ast.literal_eval``; if it is not a valid Python literal the
    input is returned unchanged.

    Args:
        value: The raw field, normally a string taken from a log line.

    Returns:
        ``np.nan`` for ``'-nan'``, the evaluated literal (int, float, ...)
        when ``value`` parses as one, otherwise ``value`` itself.
    """
    if value == '-nan':
        return np.nan
    try:
        return ast.literal_eval(value)
    # Only the errors literal_eval is documented to raise on malformed input
    # are treated as "not a literal"; a bare except would also swallow
    # KeyboardInterrupt and SystemExit.
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return value


def parse_tag(df, x_key, y_key, tag_key):
    """Build one plottable series per distinct tag in ``df``.

    For every unique value in column ``tag_key`` the matching rows are
    grouped by ``x_key``; each group contributes one point whose y value is
    the median of ``y_key`` and whose error is the standard deviation of
    ``y_key`` as computed by ``Series.std()``. Points whose median is NaN or
    exactly zero are dropped.

    Args:
        df: DataFrame holding the samples.
        x_key: Column providing the x coordinate.
        y_key: Column providing the sampled values.
        tag_key: Column used to split the data into separate series.

    Returns:
        A list of dicts, one per tag in order of first appearance, each with
        the keys ``'label'`` (the tag as a string), ``'x'``, ``'y'`` and
        ``'error'`` (parallel lists of floats).
    """
    lines = []

    for tag in df[tag_key].unique():
        tagged = df[df[tag_key] == tag]
        xs = []
        ys = []
        errors = []
        for x in tagged[x_key].unique():
            # Select the samples for this x once and reuse them for both
            # statistics.
            samples = tagged.loc[tagged[x_key] == x, y_key]
            y = samples.median()
            # NaN must be detected by value: the median is a fresh float, so
            # an identity comparison against np.nan never matches.
            if pd.isna(y) or y == 0:
                continue
            xs.append(float(x))
            ys.append(float(y))
            errors.append(float(samples.std()))
        lines.append({'label': str(tag), 'x': xs, 'y': ys, 'error': errors})
    return lines


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='lct_parse_pcounter',
        description='Parse the log generated by pcounter',
        epilog='Designed for the Lightweight Communication Tool (LCT) Library')
    parser.add_argument('filename')
    parser.add_argument('-r', '--rank', type=int, default=0)
    args = parser.parse_args()
    # The positional argument is expanded as a glob pattern, so it may match
    # several log files; their records are merged into one table.
    filenames = glob.glob(args.filename)

    labels = ["rank", "ctx_name", "time", "counter_name", "total", "count", "ave", "min", "max"]
    # Raw string so the \d and \w escapes reach the regex engine untouched;
    # compiled once because it is applied to every line of every file.
    trend_pattern = re.compile(
        r"pcounter,trend,(\d+),(\w+),(\d+),(\w+),(\d+),(\d+),(\d+),(\d+),(\d+)")
    rows = []
    for filename in filenames:
        with open(filename) as f:
            # Iterate the file lazily instead of loading every line up front.
            for line in f:
                m = trend_pattern.match(line.strip())
                if m:
                    # One captured group per label, in the same order.
                    rows.append({label: get_typed_value(d)
                                 for label, d in zip(labels, m.groups())})
                    if len(rows) % 10000 == 0:
                        print("Processing..." + str(len(rows)))
    if not rows:
        # Without this check an empty result would surface as an IndexError.
        parser.error("no 'pcounter,trend' records found in files matching "
                     + repr(args.filename))
    df = pd.DataFrame(rows, columns=labels)

    print("Apply filter...")
    # Keep only the records of the requested rank; copy so that the new
    # column below is added to an independent frame.
    df1 = df[df["rank"] == args.rank].copy()
    # Series label in the form "<ctx_name>:<counter_name>", built from the
    # filtered rows only.
    df1["ctx_counter_name"] = df1["ctx_name"].map(str) + ":" + df1["counter_name"].map(str)
    line_entries_all = parse_tag(df1, "time", "count", "ctx_counter_name")


    def create_plot(line_entries, title):
        """Build a line plot with a hide/show-all button underneath it.

        Args:
            line_entries: Iterable of dicts with the keys ``'label'``,
                ``'x'`` and ``'y'``; each one becomes a line in the figure.
            title: Title displayed on the figure.

        Returns:
            A Bokeh ``column`` layout holding the figure followed by the
            toggle button.
        """
        p = figure(title=title, x_axis_label="time", y_axis_label="count", width=1200, height=600)
        # Let auto-ranging follow only the renderers that are currently
        # visible, so hiding a line rescales the axes.
        p.x_range.only_visible = True
        p.y_range.only_visible = True

        for entry, color in zip(line_entries, itertools.cycle(palette)):
            p.line(x=entry["x"], y=entry["y"], legend_label=entry["label"], color=color, name=entry["label"])

        # A legend only exists once at least one labelled line was added;
        # with no entries, indexing p.legend would raise IndexError.
        if p.legend:
            legend = p.legend[0]
            legend.click_policy = "hide"
            legend.ncols = 2
            legend.nrows = math.ceil(len(legend.items) / legend.ncols)
            p.add_layout(legend, "right")

        # create hover tool
        hover = HoverTool(mode="vline", tooltips=[
            ("time", "$snap_x{0,0}"),
            ('counter', '$name: $snap_y'),
        ])
        hover.point_policy = 'snap_to_data'
        hover.line_policy = 'nearest'
        p.add_tools(hover)

        # Button that toggles the visibility of every renderer of the figure
        # and flips its own label between 'Hide All' and 'Show All'.
        btn = Button(label='Hide All')
        cb = CustomJS(args=dict(fig=p, btn=btn),
                      code='''
                      if (btn.label=='Hide All'){
                          for (var i=0; i<fig.renderers.length; i++){
                                  fig.renderers[i].visible=false}
                          btn.label = 'Show All'
                          }
                      else {for (var i=0; i<fig.renderers.length; i++){
                              fig.renderers[i].visible=true}
                      btn.label = 'Hide All'}
                      ''')
        btn.js_on_click(cb)

        return column([p, btn])

    print("Creating plot...")
    fig_cumu = create_plot(line_entries_all, title="Cumulative")

    # Replace every series' y values with the difference between consecutive
    # samples; a leading 0 is prepended so the first sample is kept as-is and
    # the length stays unchanged.
    for series in line_entries_all:
        padded = np.array([0, *series["y"]])
        series["y"] = list(np.diff(padded))
    fig_hist = create_plot(line_entries_all, title="Histogram")

    # Differenced plot on the left, cumulative plot on the right.
    layout = row(fig_hist, fig_cumu)
    # show the results
    show(layout)
